# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing
from pathlib import Path

from pydantic import BaseModel


class FitConfig(BaseModel):
    """
    Configuration parameters for linear fit and outlier detection.

    Every field has a default, so ``FitConfig()`` is a valid configuration.
    """
    # Concurrency ranges with at most this many points (default 8) count as small; for those,
    # extreme outliers in the raw y-values are checked first
    small_concurrency_range_threshold: int = 8

    # IQR multiplier for extreme outliers (default 2.0 times the IQR); extreme outliers are removed
    extreme_outlier_threshold: float = 2.0

    # IQR multiplier for conservative outliers (default 1.5 times the IQR); conservative outliers
    # are removed
    conservative_outlier_threshold: float = 1.5

    # Minimum R-squared value required for a valid linear fit
    min_r_squared: float = 0.7

    # Whether to remove outliers during linear fit calculation
    remove_outliers: bool = True


class CalcRunnerConfig(BaseModel):
    """
    Parameters used for a calc runner.

    Every field has a default, so ``CalcRunnerConfig()`` is a valid configuration.
    """
    # base config and endpoint (if remote) - not needed in offline mode
    config_file: Path | None = None
    # endpoint to use for the workflow, if not provided the workflow is run locally
    endpoint: str | None = None
    # timeout for the workflow (NOTE(review): presumably in seconds - confirm against the caller)
    endpoint_timeout: int = 300

    # if true the workflow is not run, instead results from previous runs are used to estimate
    # the GPU count
    offline_mode: bool = False

    # number of passes at each concurrency, if 0 the dataset is adjusted to a multiple of the
    # concurrency
    num_passes: int = 0
    # concurrency values to test
    concurrencies: list[int] = [1, 2, 4, 8]

    # Targets for GPU estimation.
    # The float targets use float literals so the stored default matches the annotation: pydantic
    # does not validate default values unless `validate_default` is enabled, so a bare `0` would
    # be stored as an int.
    target_llm_latency_p95: float = 0.0
    target_workflow_runtime_p95: float = 0.0
    target_users: int = 0

    # Test setup information needed for GPU estimation
    test_gpu_count: int = 0

    # output directory for results
    output_dir: Path | None = None
    # if true, the job is stored in a new subdirectory of the output directory
    append_job: bool = False
    # if true, the data is plotted
    plot_data: bool = True

    # Configuration for linear fit and outlier detection
    fit_config: FitConfig = FitConfig()


# Sizing metrics are gathered from the evaluation runs and used as input by the calculator.
class SizingMetricPerItem(BaseModel):
    """
    Sizing metrics per dataset entry item.

    Both fields are required; neither has a default.
    """
    # LLM latency measured for this item
    llm_latency: float
    # workflow runtime measured for this item
    workflow_runtime: float


class SizingMetricsAlerts(BaseModel):
    """
    Sizing metrics alerts.
    """
    # if true, the workflow was interrupted, so the results for that concurrency cannot be used
    workflow_interrupted: bool = False


class SizingMetrics(BaseModel):
    """
    Sizing metrics for a single concurrency.

    Every field has a default, so ``SizingMetrics()`` is a valid (empty) record.
    """
    # alerts associated with the sizing metrics
    alerts: SizingMetricsAlerts = SizingMetricsAlerts()

    # p95 LLM latency
    llm_latency_p95: float = 0.0
    # p95 workflow runtime
    workflow_runtime_p95: float = 0.0
    # total workflow runtime
    total_runtime: float = 0.0
    # per item metrics, key is the dataset entry id (the key type is unconstrained: typing.Any)
    per_item_metrics: dict[typing.Any, SizingMetricPerItem] = {}


class LinearFitResult(BaseModel):
    """
    Result of linear regression including slope, intercept, and quality metrics.

    All fields are required; none has a default.
    """
    # slope of the fitted line
    slope: float
    # intercept of the fitted line
    intercept: float
    # R-squared quality metric of the fit
    r_squared: float
    # outliers removed before fitting
    # NOTE(review): unclear from here whether these ints are point indices or concurrency values - confirm
    outliers_removed: list[int]


# GPU estimates are generated by the calculator.
class GPUEstimates(BaseModel):
    """
    GPU estimates.

    Both estimates are optional and default to None.
    """
    # GPU estimate based on the workflow runtime
    gpu_estimate_by_wf_runtime: float | None = None
    # GPU estimate based on the LLM latency
    gpu_estimate_by_llm_latency: float | None = None


# Calc runner alerts are generated by the calculator.
class CalcAlerts(BaseModel):
    """
    Calc runner alerts.

    All flags default to False and all counters default to 0.
    """
    # if true, the run was identified as an outlier by the workflow runtime linear fit
    outlier_workflow_runtime: bool = False
    # if true, the run was identified as an outlier by the LLM latency linear fit
    outlier_llm_latency: bool = False

    # number of items whose value is greater than the target latency
    num_items_greater_than_target_latency: int = 0
    # number of items whose value is greater than the target runtime
    num_items_greater_than_target_runtime: int = 0


class CalcData(BaseModel):
    """
    Output of the calc runner per concurrency.

    Every field defaults to an empty instance of its model.
    """
    # ROUGH GPU estimates per concurrency: these are not used for the final GPU estimation,
    # they are only available for information purposes
    gpu_estimates: GPUEstimates = GPUEstimates()
    # Calc runner alerts
    alerts: CalcAlerts = CalcAlerts()
    # Sizing metrics
    sizing_metrics: SizingMetrics = SizingMetrics()


class CalcRunnerOutput(BaseModel):
    """
    Output of the calc runner.

    ``gpu_estimates`` is required; ``calc_data`` defaults to an empty dict.
    """
    # GPU estimates based on the slope of the time vs concurrency, calculated online or offline
    gpu_estimates: GPUEstimates

    # Per-concurrency data (rough GPU estimates, alerts, and sizing metrics), keyed by an int
    # (presumably the concurrency value - confirm against the producer)
    calc_data: dict[int, CalcData] = {}
